

import torch

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from Modules import *
import numpy as np
from LinearModule_utils import *
from Activation_utils import *


#------------------------------------------------------------------#
#embedding structure [ signal, memory, position ] 
#I will introduce k blanks to separate different sub-sequences
#------------------------------------------------------------------#



#------------------------------------------------------------------#
#config contains the following important parameters: 
#config.signal_start : Start Index of current signal embeddings (0 always)
#config.signal_end : End Index of current signal
#config.memory_end : End Index of memorized embeddings (from a previous layer)
#config.position_start : Start index of one-hot position embeddings
#config.seq_length : Sequence length of the smaller model that we are trying to simulate
#config.blank_identifier : Index containing Identifiers for blank token
#config.num_blanks : Number of blanks to separate the sub-sequences
#config.num_attention_heads : Number of attention heads
#config.scale_embeddings : A scale to initialize different query, key matrices
#config.inner_lr : Inner learning rate to simulate sgd 
#config.gate_scale: Scale to use inside gates  
#------------------------------------------------------------------# 


#------------------------------------------------------------------#
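#GLUForward module computes Wx * Act (Ux) and passes the product through a gate
#arguments: config, input dimension (din), use_softmax, memory_index (-1 disables memory)
#output: gated hidden states; when memory_index != -1, [Wx, Ux, x] are also written to memory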
#------------------------------------------------------------------#

class GLUForward(nn.Module):
    """Forward pass of a gated linear unit: ``gates(x, (W x) * Act(U x))``.

    One ``LinearForward`` sub-module is applied twice: once to the hidden
    states as given, and once after the first ``config.num_blanks`` positions
    (dim 1) have been replaced by ``activation_weights``.
    NOTE(review): presumably the blank positions carry the weights that
    ``LinearForward`` applies (W in the first call, U in the second) --
    confirm against ``LinearForward``.

    Args:
        config: Model configuration. This class reads ``hidden_size``,
            ``position_dim``, ``seq_length``, ``num_blanks`` and
            ``gate_scale`` from it.
        din: Width of the signal; also used as the output width of the
            linear sub-module.
        use_softmax: Forwarded unchanged to ``LinearForward``.
        memory_index: First column of the memory region in which
            ``forward`` accumulates ``[Wx, Ux, x]`` (``3 * din`` columns).
            ``-1`` disables the memory write.
    """

    def __init__(self, config, din, use_softmax, memory_index):
        super(GLUForward, self).__init__()

        self.config = config
        self.din = din

        # The memory region needs 3*din columns and must start after the
        # first din (signal) columns.
        assert memory_index == -1 or self.config.hidden_size - memory_index >= 3*self.din, \
              "Not enough space for memory"
        assert memory_index == -1 or memory_index >= self.din, \
              "Interacts with signal"

        self.memory_index = memory_index

        # The sub-modules never write memory themselves (memory_index=-1);
        # this class does it in forward().
        self.linear = LinearForward(config, din=din, dout=din, use_softmax=use_softmax, memory_index=-1)
        self.activation = ActivationForward(config, din, memory_index=-1)
        self.gates = Gates(config)

        # Initialize Gates. Only the position-dependent weights (v, v_bias)
        # are non-zero, so the gate values depend on position alone.
        # Argument order expected below: w, u, v, w_bias, u_bias, v_bias
        w = torch.zeros((1, 2 * config.hidden_size))
        u = torch.zeros((1, 2 * config.hidden_size))
        v = torch.zeros((1, 2 * config.position_dim))
        w_bias = torch.zeros(2)
        u_bias = torch.zeros(2)
        v_bias = torch.zeros(2)

        # Input Gate is 1 on blanks and 0 for non-blanks
        # (first half of v: +gate_scale on columns seq_length..position_dim).
        v[0, config.seq_length: config.position_dim] = config.gate_scale * torch.ones(config.num_blanks)

        # Change Gate is 0 on blanks and 1 for non-blanks
        # (second half of v: -gate_scale on the same columns, +gate_scale bias).
        v[0, config.position_dim + config.seq_length: 2 * config.position_dim] = -config.gate_scale * torch.ones(config.num_blanks)
        v_bias[1] += config.gate_scale

        self.gates.initialize_weights(w, u, v, w_bias, u_bias, v_bias)

    def forward(self, hidden_states, position_states, activation_weights):
        """Compute the gated GLU output and optionally record ``[Wx, Ux, x]``.

        Args:
            hidden_states: Input embeddings; indexed as
                ``[batch, position, feature]`` with the first
                ``config.num_blanks`` positions being the blanks.
            position_states: Position embeddings, passed through to every
                sub-module.
            activation_weights: Tensor concatenated along dim 1 in place of
                the blank positions for the second linear pass.

        Returns:
            The output of the gate module. When ``memory_index != -1`` the
            non-blank rows additionally have ``Wx``, ``Ux`` and ``x`` (first
            ``din`` columns of each) added into the three consecutive
            ``din``-wide memory slots starting at ``memory_index``.
        """
        num_blanks = self.config.num_blanks

        # Invoke the sub-module itself (not .forward()) so that any
        # registered forward hooks run, consistent with activation/gates.
        linear_output = self.linear(hidden_states, position_states)

        # Second pass through the same linear module, with the blank rows
        # swapped for activation_weights.
        input_states = torch.cat([activation_weights, hidden_states[:, num_blanks:]], dim=1)
        activation_input = self.linear(input_states, position_states)
        activation_output = self.activation(activation_input, position_states)

        gated_output = self.gates(hidden_states, linear_output * activation_output, position_states)

        if self.memory_index != -1:
            # store Wx, Ux, x (non-blank rows only) for the backward module
            wx_start = self.memory_index
            ux_start = wx_start + self.din
            x_start = ux_start + self.din
            x_end = x_start + self.din

            gated_output[:, num_blanks:, wx_start:ux_start] += linear_output[:, num_blanks:, :self.din]
            gated_output[:, num_blanks:, ux_start:x_start] += activation_input[:, num_blanks:, :self.din]
            gated_output[:, num_blanks:, x_start:x_end] += hidden_states[:, num_blanks:, :self.din]

        return gated_output
    
    
#First din coordinates contain \nabla y    
#Memory has [WX, UX, x]

class GLUBackward_Descent(nn.Module):
    """Backward pass (with a descent step) through the gated linear unit.

    Expects the first ``din`` columns of the non-blank rows to hold the
    incoming gradient (nabla y) and the memory region starting at
    ``memory_index`` to hold ``[Wx, Ux, x]`` in three ``din``-wide slots.

    Args:
        config: Model configuration; ``num_blanks`` is read here, the rest is
            forwarded to the sub-modules.
        din: Width of the signal and of each memory slot.
        use_softmax: Forwarded to ``Linear_Descent_Backward``.
        memory_index: First column of the ``[Wx, Ux, x]`` memory region.
            NOTE(review): with the default ``-1`` the slices taken in
            ``forward`` do not address a valid memory region, so callers
            appear to need to pass an explicit index -- confirm.
        debug_zero: Forwarded to ``Linear_Descent_Backward``.
        projection_matrix: Forwarded to ``Linear_Descent_Backward``.
        retain_nablay: Forwarded to ``Linear_Descent_Backward``.
    """

    def __init__(self, config, din, use_softmax, memory_index=-1, debug_zero=False, projection_matrix=None, retain_nablay=False):
        super(GLUBackward_Descent, self).__init__()

        self.config = config
        self.din = din
        self.memory_index = memory_index

        # Activation backward reads the second memory slot (Ux).
        self.activation_backward = ActivationBackward(
            config,
            din=din,
            input_projection=None,
            projection_matrix=None,
            memory_index=memory_index + self.din,
            retain_og_act=True,
        )

        # Linear backward reads the third memory slot (x).
        self.linearback_descent = Linear_Descent_Backward(
            config,
            din=din,
            dout=din,
            use_softmax=use_softmax,
            memory_index=memory_index + 2 * self.din,
            debug_zero=debug_zero,
            projection_matrix=projection_matrix,
            retain_nablay=retain_nablay,
        )

    def _scale_signal(self, hidden_states, source, start):
        """Multiply the signal columns of the non-blank rows of ``hidden_states``
        element-wise by ``source[..., start:start + din]``; all other entries
        are multiplied by one (left unchanged). Returns a new tensor."""
        num_blanks = self.config.num_blanks
        # ones_like already inherits dtype and device from hidden_states.
        scale = torch.ones_like(hidden_states)
        scale[:, num_blanks:, :self.din] = source[:, num_blanks:, start: start + self.din]
        return hidden_states * scale

    def forward(self, hidden_states, position_states, attention_mask=None, activation_weights=None):
        """Back-propagate through both linear branches of the GLU.

        Args:
            hidden_states: Embeddings indexed ``[batch, position, feature]``;
                the first ``config.num_blanks`` positions are the blanks.
            position_states: Position embeddings, passed to every sub-module.
            attention_mask: Passed through to ``linearback_descent``.
            activation_weights: Rows placed at the blank positions for the
                second (U) backward pass. Required despite the ``None``
                default in the signature.

        Returns:
            ``(W_gradient_backprop, U_gradient_backprop)``, where the
            non-blank rows of the first already include the non-blank rows of
            the second.

        Raises:
            ValueError: If ``activation_weights`` is ``None``.
        """
        # Fail fast with a clear message: torch.cat below cannot take None,
        # and otherwise the error only surfaces after both backward passes ran.
        if activation_weights is None:
            raise ValueError(
                "GLUBackward_Descent.forward requires activation_weights; got None"
            )

        num_blanks = self.config.num_blanks
        wx_start = self.memory_index
        ux_start = wx_start + self.din
        x_start = ux_start + self.din
        x_end = x_start + self.din

        # First compute \nabla y * Gates' (Ux):
        # scale \nabla y by the first memory slot, then run activation backward.
        # Ux has been replaced by Gates (Ux) in the memory afterwards
        # (original author's note; relies on retain_og_act=True).
        act_output = self.activation_backward(
            self._scale_signal(hidden_states, hidden_states, wx_start),
            position_states,
        )

        # Scale \nabla y by the second memory slot of act_output and
        # back-propagate through the linear layer (W branch).
        W_gradient_backprop = self.linearback_descent(
            self._scale_signal(hidden_states, act_output, ux_start),
            position_states,
            attention_mask,
        )

        # Also back-propagate through the weights inside the gate (U branch):
        # blank rows come from activation_weights, the rest from act_output.
        input_states = torch.cat([activation_weights, act_output[:, num_blanks:]], dim=1)
        # Add the stored x slot from hidden_states onto the same slot.
        input_states[:, num_blanks:, x_start:x_end] += hidden_states[:, num_blanks:, x_start:x_end]

        U_gradient_backprop = self.linearback_descent(input_states, position_states, attention_mask)

        # finally, add W_gradient_backprop and U_gradient_backprop in the non blank region
        W_gradient_backprop[:, num_blanks:] += U_gradient_backprop[:, num_blanks:]

        return W_gradient_backprop, U_gradient_backprop
        
       
    

